'''
-*- coding: utf-8 -*-
@File  : 01-demo.py
@Author: Shanmh
@Time  : 2024/03/06 下午4:03
@Function： openvino yolov8-cpu 推理
1座10核：Intel(R) Core(TM) i9-10900K CPU @ 3.70GHz
测试：
 1000张约 84fps


'''
import os.path
import time

from ultralytics import YOLO
from typing import Tuple
from ultralytics.utils import ops
import torch
import numpy as np
import ipywidgets as widgets
import openvino as ov
from typing import Tuple, Dict
import cv2
import numpy as np
from ultralytics.utils.plotting import colors
import random
import xml.etree.ElementTree as ET

def plot_one_box(box:np.ndarray, img:np.ndarray,
                 color:Tuple[int, int, int] = None,
                 label:str = None, line_thickness:int = 5):
    """Draw one bounding box, plus an optional text tag, onto ``img`` in place.

    Args:
        box: Indexable whose items 0..3 are the x1, y1, x2, y2 pixel coordinates.
        img: Image to draw on; it is modified in place and also returned.
        color: Box/tag colour. A random colour is picked when it is falsy.
        label: Text drawn in a filled tag at the box's top-left corner; skipped when falsy.
        line_thickness: Box line width. When falsy, a width derived from the image size is used.

    Returns:
        The same ``img`` object with the drawing applied.
    """
    thickness = line_thickness
    if not thickness:
        # Scale the line width with the mean of the image height and width.
        thickness = round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    x1, y1, x2, y2 = (int(box[k]) for k in range(4))
    top_left = (x1, y1)
    cv2.rectangle(img, top_left, (x2, y2), color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return img

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # Filled tag sits just above the top-left corner and is as wide as the text.
    tag_corner = (x1 + text_w, y1 - text_h - 3)
    cv2.rectangle(img, top_left, tag_corner, color, -1, cv2.LINE_AA)
    cv2.putText(img, label, (x1, y1 - 2), 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)
    return img

def letterbox(img: np.ndarray, new_shape:Tuple[int, int] = (224, 224), color:Tuple[int, int, int] = (114, 114, 114), auto:bool = False, scale_fill:bool = False, scaleup:bool = False, stride:int = 32):
    """Resize ``img`` keeping its aspect ratio, then pad it with a constant border.

    Args:
        img: HWC image.
        new_shape: Target (height, width), or a single int for a square target.
        color: Fill value for the border.
        auto: If true, pad only by the padding remainder modulo ``stride``
            (minimum rectangle) instead of up to the full ``new_shape``.
        scale_fill: Used only when ``auto`` is false. Stretch to exactly
            ``new_shape`` with no padding.
        scaleup: If false, the image is only ever shrunk, never enlarged.
        stride: Stride used by ``auto``.

    Returns:
        ``(padded_img, (w_ratio, h_ratio), (pad_w, pad_h))``. The pads are the
        per-side amounts (half of the total) and may be fractional.
    """
    src_h, src_w = img.shape[:2]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    dst_h, dst_w = new_shape[0], new_shape[1]

    # Uniform scale that fits the image inside the target.
    scale = min(dst_h / src_h, dst_w / src_w)
    if not scaleup:
        scale = min(scale, 1.0)
    ratio = scale, scale

    resized_wh = int(round(src_w * scale)), int(round(src_h * scale))
    pad_w, pad_h = dst_w - resized_wh[0], dst_h - resized_wh[1]
    if auto:
        pad_w, pad_h = np.mod(pad_w, stride), np.mod(pad_h, stride)
    elif scale_fill:
        pad_w, pad_h = 0.0, 0.0
        resized_wh = (dst_w, dst_h)
        ratio = dst_w / src_w, dst_h / src_h

    # Split the padding evenly between the two sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != resized_wh:
        img = cv2.resize(img, resized_wh, interpolation=cv2.INTER_LINEAR)
    # The +/-0.1 nudges put the extra pixel of an odd padding on bottom/right.
    top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
    left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return img, ratio, (pad_w, pad_h)






class Detection:
    """YOLOv8 object detector running on an OpenVINO IR model.

    Usage::

        det = Detection("yolov8s.pt")   # exports to OpenVINO IR if not present yet
        det.init_()                     # read + compile the model; required before detect()
        annotated, result = det.detect(image)
    """

    def __init__(self, modelpath):
        """Resolve the IR path and read the input size and class names from its metadata.

        Args:
            modelpath: Path to an OpenVINO ``.xml`` IR file, or to a ``.pt``
                checkpoint that is exported to IR on demand (see ``conv_``).

        Raises:
            ValueError: If the path has an unsupported extension, or the IR XML
                lacks the ``rt_info/framework`` ``imgsz``/``names`` metadata.
        """
        self.modelpath = self.conv_(modelpath)
        root = ET.parse(self.modelpath).getroot()
        rt_info = root.find("rt_info")
        framework = rt_info.find("framework") if rt_info is not None else None
        if framework is None:
            raise ValueError(f"{self.modelpath}: missing rt_info/framework metadata")

        imgsz = self._read_literal(framework, "imgsz")
        if isinstance(imgsz, int):
            # init_ indexes imgsz[0]/imgsz[1], so normalise a scalar to a square size.
            imgsz = (imgsz, imgsz)
        self.imgsz = imgsz
        self.label_map = self._read_literal(framework, "names")  # {class_id: name}
        self.classes = len(self.label_map)

    @staticmethod
    def _read_literal(framework, tag):
        """Return the Python literal stored in ``<tag value="...">`` under ``framework``."""
        import ast

        node = framework.find(tag)
        if node is None or "value" not in node.attrib:
            raise ValueError(f"model metadata is missing <{tag} value=...>")
        # The value is a Python literal written as text (e.g. "[640, 640]").
        # literal_eval accepts only literals, so a crafted model file cannot run
        # code the way eval() would; anything else raises ValueError/SyntaxError.
        return ast.literal_eval(node.attrib["value"])

    def conv_(self, modelpath, half=True):
        """Return the path of the OpenVINO IR ``.xml`` file for ``modelpath``.

        ``.xml`` paths are returned unchanged. For a ``.pt`` checkpoint the IR is
        expected at ``<dir>/<stem>_openvino_model/<stem>.xml``; if it is missing,
        the checkpoint is exported via Ultralytics (``dynamic=True``, ``half=half``).

        Raises:
            ValueError: If ``modelpath`` ends in neither ``.xml`` nor ``.pt``.
            FileNotFoundError: If the expected IR file is still missing after export.
        """
        if modelpath.endswith(".xml"):
            return modelpath
        if not modelpath.endswith(".pt"):
            raise ValueError(f"unsupported model file (expected .xml or .pt): {modelpath}")

        base = modelpath[:-3]  # strip ".pt"
        stem = os.path.basename(base)
        xml_path = os.path.join(base + "_openvino_model", stem + ".xml")
        if not os.path.exists(xml_path):
            YOLO(modelpath).export(format="openvino", dynamic=True, half=half)
            if not os.path.exists(xml_path):
                raise FileNotFoundError(f"OpenVINO export did not produce {xml_path}")
        return xml_path

    def init_(self, device="AUTO"):
        """Create the OpenVINO core, read the IR and compile it for ``device``.

        Args:
            device: OpenVINO device name passed to ``compile_model``
                (default ``"AUTO"``).

        For any device other than ``"CPU"``, the model input is reshaped to the
        static shape ``[1, 3, imgsz[0], imgsz[1]]``. On ``"CPU"`` the IR's own
        input shape is kept.
        """
        self.core = ov.Core()
        self.det_ov_model = self.core.read_model(self.modelpath)
        if device != "CPU":
            self.det_ov_model.reshape({0: [1, 3, self.imgsz[0], self.imgsz[1]]})
        self.det_compiled_model = self.core.compile_model(self.det_ov_model, device)

    def draw_results(self, results: Dict, source_image: np.ndarray, label_map: Dict):
        """Draw every row of ``results["det"]`` onto ``source_image`` and return it.

        Each row is unpacked as ``x1, y1, x2, y2, conf, class_id``. The tag text
        is ``"<label_map[class_id]> <conf:.2f>"``. Drawing happens in place.
        """
        for *xyxy, conf, lbl in results["det"]:
            class_id = int(lbl)
            label = f'{label_map[class_id]} {conf:.2f}'
            source_image = plot_one_box(xyxy, source_image, label=label,
                                        color=colors(class_id), line_thickness=1)
        return source_image

    def preprocess_image(self, img0: np.ndarray):
        """Letterbox ``img0`` to ``self.imgsz`` and convert it from HWC to contiguous CHW."""
        img = letterbox(img0, self.imgsz)[0]
        return np.ascontiguousarray(img.transpose(2, 0, 1))

    def image_to_tensor(self, image: np.ndarray):
        """Convert an image array to float32, scale it to [0, 1], and add a batch axis if 3-D."""
        input_tensor = image.astype(np.float32)  # e.g. uint8 -> fp32
        input_tensor /= 255.0  # 0..255 -> 0.0..1.0
        if input_tensor.ndim == 3:
            input_tensor = np.expand_dims(input_tensor, 0)
        return input_tensor

    def postprocess(self,
            pred_boxes: np.ndarray,
            input_hw: Tuple[int, int],
            orig_img: np.ndarray,
            min_conf_threshold: float = 0.25,
            nms_iou_threshold: float = 0.7,
            agnosting_nms: bool = False,
            max_detections: int = 300,
    ):
        """Apply NMS to raw model output and rescale boxes to the original image.

        Args:
            pred_boxes: Output 0 of the compiled model as a NumPy array.
            input_hw: Spatial size of the tensor fed to the model. It is used to
                map boxes from letterboxed coordinates back to ``orig_img``.
            orig_img: The original image, or a list of images (one per batch item).
            min_conf_threshold: Confidence threshold passed to NMS.
            nms_iou_threshold: IoU threshold passed to NMS.
            agnosting_nms: Passed to NMS as ``agnostic`` (name kept for compatibility).
            max_detections: Passed to NMS as ``max_det``.

        Returns:
            One dict per batch item. For an item with detections, ``"det"`` is the
            NMS output with columns 0..3 rescaled to original-image pixels. For an
            item without detections, it is ``{"det": [], "segment": []}``.
        """
        nms_kwargs = {"agnostic": agnosting_nms, "max_det": max_detections}
        preds = ops.non_max_suppression(
            torch.from_numpy(pred_boxes),
            min_conf_threshold,
            nms_iou_threshold,
            nc=self.classes,
            **nms_kwargs
        )

        results = []
        for i, pred in enumerate(preds):
            shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape
            if not len(pred):
                results.append({"det": [], "segment": []})
                continue
            pred[:, :4] = ops.scale_boxes(input_hw, pred[:, :4], shape).round()
            results.append({"det": pred})
        return results

    def detect(self, image, draw=True):
        """Run detection on a single image.

        Args:
            image: HWC image with values in 0..255. NOTE(review): channel order must
                match what the model was trained on (the demo passes RGB).
            draw: If true, draw the detections onto ``image`` in place.

        Returns:
            ``(image, detections)``, where ``detections`` is the first dict
            returned by ``postprocess``.

        Raises:
            RuntimeError: If ``init_`` has not been called yet.
        """
        if getattr(self, "det_compiled_model", None) is None:
            raise RuntimeError("Detection.init_() must be called before detect()")
        preprocessed_image = self.preprocess_image(image)
        input_tensor = self.image_to_tensor(preprocessed_image)
        result = self.det_compiled_model(input_tensor)
        boxes = result[self.det_compiled_model.output(0)]
        input_hw = input_tensor.shape[2:]
        detections = self.postprocess(pred_boxes=boxes, input_hw=input_hw, orig_img=image)[0]
        if draw:
            image = self.draw_results(detections, image, label_map=self.label_map)
        return image, detections


if __name__ == '__main__':
    # Throughput benchmark: run `test_count` detections on one image and report FPS.
    # Usage: python 01-demo.py [image_path]
    import sys

    image_path = sys.argv[1] if len(sys.argv) > 1 else ""
    # Read the image first so a bad path fails before the (slow) model export/compile.
    img = cv2.imread(image_path)
    if img is None:
        raise SystemExit(f"could not read image: {image_path!r} (pass a path as the first argument)")
    # cv2.imread returns BGR; the detector is fed RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    detecter = Detection("/models/yolov8s.pt")
    detecter.init_()

    test_count = 1000
    st_time = time.time()
    for _ in range(test_count):
        show, result = detecter.detect(img.copy())
        print(result, len(result["det"]))
        print(show.shape)
        # cv2.imshow expects BGR, so convert the annotated RGB frame back for display.
        cv2.imshow("a", cv2.cvtColor(show, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)
    elapsed = time.time() - st_time  # measured once so both printed numbers agree
    cv2.destroyAllWindows()
    print(f"calculate {test_count} images spend {elapsed} ; speed: {test_count / elapsed}/s")

